from torch import nn
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel

from ai.store_model import store_model_name, store_model


class GXLLMHeadModel(nn.Module):
    """Project a backbone's hidden states onto vocabulary logits.

    By default the backbone is loaded from the project model store under
    ``store_model_name.GXL_TRANSFORMER``; a bias-free linear layer then maps
    each hidden vector to ``vocab_size`` scores.

    Args:
        hidden_size: Width of the backbone's output features (``in_features``
            of the head). Defaults to 768, the previously hard-coded value.
        vocab_size: Number of logits produced per position (``out_features``
            of the head). Defaults to 30000, the previously hard-coded value.
        transformer: Optional pre-built backbone module. When ``None`` (the
            default) the backbone is loaded from the model store exactly as
            before; passing one allows alternative backbones and store-free
            construction (e.g. in tests).
    """

    def __init__(
        self,
        hidden_size: int = 768,
        vocab_size: int = 30000,
        transformer: "nn.Module | None" = None,
    ):
        super().__init__()
        if transformer is None:
            # NOTE(review): the stored model's output width is assumed to equal
            # ``hidden_size`` -- confirm against GXL_TRANSFORMER's config.
            transformer = store_model.load_model(store_model_name.GXL_TRANSFORMER)
        self.transformer = transformer
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    @staticmethod
    def _hidden_states(output):
        """Extract the hidden-state tensor from a backbone's return value.

        Plain tensors pass through unchanged. Objects exposing
        ``last_hidden_state`` (such as Hugging Face ``ModelOutput``) yield
        that attribute; tuples/lists yield their first element, assumed to be
        the hidden states as in Hugging Face's tuple outputs. Feeding either
        container directly to ``nn.Linear`` would raise.
        """
        if hasattr(output, "last_hidden_state"):
            return output.last_hidden_state
        if isinstance(output, (tuple, list)):
            return output[0]
        return output

    def forward(self, x):
        """Run the backbone on ``x`` and return vocabulary logits.

        Args:
            x: Input passed unchanged to the backbone; its expected type and
                shape are defined by the backbone, not by this class.

        Returns:
            ``lm_head`` applied to the backbone's hidden states; the last
            dimension has size ``vocab_size``.
        """
        hidden = self._hidden_states(self.transformer(x))
        return self.lm_head(hidden)


